import cv2
import gym
import highway_env
from stable_baselines3 import PPO
from sb3_contrib import TRPO
import torch


# Environment id to train and evaluate on. Swap in "intersection-v0" to use
# the intersection scenario instead.
situation = "racetrack-v1"


# Output video: frame size as (width, height) in pixels, written at 4 frames
# per second with the 'mp4v' codec into an .avi file named after the
# environment. (An earlier variant used the 'DIVX' codec at 16 fps.)
frameSize = (1280, 560)
out = cv2.VideoWriter(
    f"video{situation}.avi",
    cv2.VideoWriter_fourcc("m", "p", "4", "v"),
    4,
    frameSize,
)


# Create the environment and give its rendering surface the same resolution
# as the video writer's frame size (frameSize is (width, height)).
env = gym.make(situation)
env.configure(dict(
    screen_width=frameSize[0],
    screen_height=frameSize[1],
    renderfps=16,  # NOTE(review): confirm the environment actually reads this key
))

# Reset once after configuring, before the model is constructed.
env.reset()
# NOTE(review): n_cpu and batch_size are not referenced anywhere else in the
# visible script; the model below is given batch_size=128 explicitly.
n_cpu = 6
batch_size = 64

# TRPO hyperparameters, gathered in one mapping and unpacked into the
# constructor call below.
trpo_settings = dict(
    learning_rate=0.001,
    n_steps=2048,
    batch_size=128,
    gamma=0.99,
    cg_max_steps=15,
    cg_damping=0.1,
    line_search_shrinking_factor=0.8,
    line_search_max_iter=10,
    n_critic_updates=10,
    gae_lambda=0.95,
    use_sde=False,
    sde_sample_freq=-1,
    normalize_advantage=True,
    target_kl=0.01,
    sub_sampling_factor=1,
    tensorboard_log=None,
    policy_kwargs=None,
    verbose=0,
    seed=None,
    device='auto',
    _init_setup_model=True,
)

# MLP policy trained on the environment created above.
model = TRPO("MlpPolicy", env, **trpo_settings)

# Where the trained model is stored. The path is derived from the environment
# id (the `situation` variable, not the literal text 'situation') so that
# models trained on different environments do not overwrite each other.
model_path = situation + '_trpo/model'

# Train a fresh model and save it. Comment out the two lines below to skip
# training and reuse a model saved by an earlier run.
model.learn(int(1e3))
model.save(model_path)

print()
print("Done Learning!!")
print()


########## Load and test saved model ##############
# Reload from disk so evaluation always runs against the saved artifact.
model = TRPO.load(model_path)
# Roll out 40 evaluation episodes with the deterministic policy, writing every
# rendered frame to the video and counting the steps whose info reports a crash.
number_of_collisions = 0
try:
    for _episode in range(40):
        done = truncated = False
        obs, info = env.reset()
        while not (done or truncated):
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(action)
            print(action)
            print(obs)
            print(info)
            if info.get('crashed'):
                number_of_collisions += 1
            env.render()
            cur_frame = env.render(mode="rgb_array")
            # The frame is requested as an RGB array, but OpenCV's writer
            # expects BGR channel order; convert so colors are not swapped.
            out.write(cv2.cvtColor(cur_frame, cv2.COLOR_RGB2BGR))
finally:
    # Always finalize the video file, even if an episode raises midway.
    out.release()

print('number_of_collisions is:', number_of_collisions)
print('DONE')